{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 程序说明\n",
    "时间：2016年11月16日\n",
    "\n",
    "说明：该程序使用 Keras 构建一个包含 LSTM 层的神经网络，用于 MNIST 手写数字分类。\n",
    "\n",
    "数据集：MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.加载keras模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import LSTM, Dense\n",
    "from keras.datasets import mnist\n",
    "from keras.utils import np_utils\n",
    "from keras import initializers\n",
    "\n",
    "def init_weights(shape, name=None, dtype=None):\n",
    "    \"\"\"Return weights of the given shape drawn from a normal distribution (stddev 0.01).\n",
    "\n",
    "    The previous body called `initializers.normal(shape, scale=..., name=...)`,\n",
    "    which is the Keras 1 signature. The Keras 2 initializers used elsewhere in\n",
    "    this notebook are configured first and then called on a shape, so the\n",
    "    initializer object is built and applied here.\n",
    "\n",
    "    Args:\n",
    "        shape: shape of the weight tensor to create.\n",
    "        name: unused; kept so existing callers that pass it keep working.\n",
    "        dtype: optional dtype forwarded to the initializer.\n",
    "\n",
    "    Returns:\n",
    "        Whatever the Keras initializer returns for `shape`.\n",
    "    \"\"\"\n",
    "    return initializers.random_normal(stddev=0.01)(shape, dtype=dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 如需绘制模型结构图，请加载 plot_model（需要先安装 pydot 和 graphviz）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.utils.vis_utils import plot_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.变量初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training hyper-parameters\n",
    "batch_size = 128\n",
    "nb_epoch = 10\n",
    "\n",
    "# MNIST geometry and number of target classes\n",
    "img_rows = 28\n",
    "img_cols = 28\n",
    "nb_classes = 10\n",
    "\n",
    "# LSTM layout: one time step per image row, one input feature per image column\n",
    "nb_lstm_outputs = 30\n",
    "nb_time_steps, dim_input_vector = img_rows, img_cols"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('X_train original shape:', (60000, 28, 28))\n",
      "('X_train shape:', (60000, 28, 28))\n",
      "(60000, 'train samples')\n",
      "(10000, 'test samples')\n"
     ]
    }
   ],
   "source": [
    "# Load the MNIST train/test split\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "# str.format keeps the printed text identical on Python 2 and 3\n",
    "# (a multi-argument print() shows a tuple repr under the Python 2 kernel)\n",
    "print('X_train original shape: {}'.format(X_train.shape))\n",
    "\n",
    "# Shape of one sample as fed to the LSTM: (time steps, features per step)\n",
    "input_shape = (nb_time_steps, dim_input_vector)\n",
    "assert X_train.shape[1:] == input_shape, 'training images do not match the LSTM input shape'\n",
    "\n",
    "# Convert pixels to float32 and divide by 255\n",
    "X_train = X_train.astype('float32') / 255.\n",
    "X_test = X_test.astype('float32') / 255.\n",
    "\n",
    "# Encode the integer labels as nb_classes-wide categorical matrices\n",
    "Y_train = np_utils.to_categorical(y_train, nb_classes)\n",
    "Y_test = np_utils.to_categorical(y_test, nb_classes)\n",
    "\n",
    "print('X_train shape: {}'.format(X_train.shape))\n",
    "print('{} train samples'.format(X_train.shape[0]))\n",
    "print('{} test samples'.format(X_test.shape[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.建立模型\n",
    "### 使用 Sequential()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LSTM over each input sequence, followed by a softmax layer over the classes\n",
    "model = Sequential([\n",
    "    LSTM(nb_lstm_outputs, input_shape=input_shape),\n",
    "    Dense(\n",
    "        nb_classes,\n",
    "        activation='softmax',\n",
    "        kernel_initializer=initializers.random_normal(stddev=0.01)\n",
    "    )\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 打印模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "lstm_1 (LSTM)                (None, 30)                7080      \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                310       \n",
      "=================================================================\n",
      "Total params: 7,390\n",
      "Trainable params: 7,390\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 绘制模型结构图，并保存成图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-87c2bcbfc5a7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplot_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mto_file\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'lstm_model.png'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/opt/anaconda/lib/python2.7/site-packages/keras/utils/vis_utils.pyc\u001b[0m in \u001b[0;36mplot_model\u001b[0;34m(model, to_file, show_shapes, show_layer_names, rankdir)\u001b[0m\n\u001b[1;32m    129\u001b[0m             \u001b[0;34m'LR'\u001b[0m \u001b[0mcreates\u001b[0m \u001b[0ma\u001b[0m \u001b[0mhorizontal\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m     \"\"\"\n\u001b[0;32m--> 131\u001b[0;31m     \u001b[0mdot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_to_dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshow_shapes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshow_layer_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrankdir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m     \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextension\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplitext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mto_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mextension\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda/lib/python2.7/site-packages/keras/utils/vis_utils.pyc\u001b[0m in \u001b[0;36mmodel_to_dot\u001b[0;34m(model, show_shapes, show_layer_names, rankdir)\u001b[0m\n\u001b[1;32m     50\u001b[0m     \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSequential\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m     \u001b[0m_check_pydot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     53\u001b[0m     \u001b[0mdot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpydot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m     \u001b[0mdot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'rankdir'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrankdir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda/lib/python2.7/site-packages/keras/utils/vis_utils.pyc\u001b[0m in \u001b[0;36m_check_pydot\u001b[0;34m()\u001b[0m\n\u001b[1;32m     25\u001b[0m         \u001b[0;31m# pydot raises a generic Exception here,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m         \u001b[0;31m# so no specific class can be caught.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m         raise ImportError('Failed to import pydot. You must install pydot'\n\u001b[0m\u001b[1;32m     28\u001b[0m                           ' and graphviz for `pydotprint` to work.')\n\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mImportError\u001b[0m: Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work."
     ]
    }
   ],
   "source": [
    "# plot_model needs the optional pydot and graphviz packages. Without them keras\n",
    "# raises ImportError, so report that instead of aborting a full notebook run.\n",
    "try:\n",
    "    plot_model(model, to_file='lstm_model.png')\n",
    "except ImportError as exc:\n",
    "    print('Skipping model plot: {}'.format(exc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 显示绘制的图片\n",
    "![image](http://p1.bqimg.com/4851/08854414148a36b0.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.训练与评估\n",
    "### 编译模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training setup: rmsprop optimizer, categorical cross-entropy loss, accuracy as the reported metric\n",
    "model.compile(\n",
    "    loss='categorical_crossentropy',\n",
    "    optimizer='rmsprop',\n",
    "    metrics=['accuracy']\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 迭代训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda/lib/python2.7/site-packages/keras/models.py:848: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 22s - loss: 2.2854 - acc: 0.1381    \n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.2874 - acc: 0.1296    \n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1118    \n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1118    \n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3019 - acc: 0.1115    \n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1119    \n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1119    \n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1121    \n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1119    \n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 21s - loss: 2.3018 - acc: 0.1120    \n"
     ]
    }
   ],
   "source": [
    "# Fit on the training set; the value returned by fit() is kept in `history`\n",
    "history = model.fit(\n",
    "    X_train,\n",
    "    Y_train,\n",
    "    batch_size=batch_size,\n",
    "    epochs=nb_epoch,\n",
    "    shuffle=True,\n",
    "    verbose=1\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000/10000 [==============================] - 5s     \n",
      "('Test score:', 0.11906883909329773)\n",
      "('Test accuracy:', 0.96530000000000005)\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the held-out test set; score[0] is reported as the test score\n",
    "# and score[1] as the test accuracy.\n",
    "score = model.evaluate(X_test, Y_test, verbose=1)\n",
    "# str.format keeps the printed text identical on Python 2 and 3\n",
    "# (a multi-argument print() shows a tuple repr under the Python 2 kernel)\n",
    "print('Test score: {}'.format(score[0]))\n",
    "print('Test accuracy: {}'.format(score[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  },
  "ssap_exp_config": {
   "error_alert": "Error Occurs!",
   "initial": [],
   "max_iteration": 1000,
   "recv_id": "",
   "running": [],
   "summary": [],
   "version": "1.1.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
